
Support group query attention in Attention(23) CUDA #27082

Open

Copilot wants to merge 26 commits into main from copilot/support-group-query-attention

Conversation

Contributor

Copilot AI commented Jan 20, 2026

This pull request improves and fixes the attention mechanism in ONNX Runtime, focusing on attention-mask handling and the computation of attention probabilities for both the CPU and CUDA providers. The most significant changes are a new CUDA implementation that converts boolean attention masks to sequence lengths (with validation) for the GQA kernels, and several bug fixes in the CPU attention kernel so that head indices are handled correctly during computation.

CUDA Attention Mask Conversion and Validation:

  • Added a new CUDA implementation (attention_mask_impl.cu and attention_mask_impl.h) that efficiently converts a boolean attention mask to sequence lengths for GQA (Grouped Query Attention) kernels. This includes:
    • A CUDA kernel that processes each batch, validates that the mask starts with True and that padding is contiguous (right-padding only), and computes the correct sequence length per batch.
    • Support for 2D, 3D, and 4D mask shapes with proper broadcasting logic.
    • Error handling for masks that do not start with True or that contain non-contiguous True/False values (a minimal sketch of this kernel follows this list). [1] [2]
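
Below is a minimal sketch of such a mask-to-sequence-lengths kernel, assuming a 2D boolean mask of shape [batch_size, total_seq_len]; the kernel and parameter names are illustrative and are not the actual contents of attention_mask_impl.cu.

// Sketch only: convert a right-padded boolean mask into per-batch sequence lengths
// and flag masks that violate the contiguity rules described above.
#include <cuda_runtime.h>

__global__ void MaskToSeqLensKernel(const bool* mask,   // [batch_size, total_seq_len]
                                    int* seqlens,       // [batch_size]
                                    int* error_flag,    // set to 1 on an invalid mask
                                    int total_seq_len) {
  const int b = blockIdx.x;          // one block per batch entry; thread 0 scans the row
  if (threadIdx.x != 0) return;
  const bool* row = mask + static_cast<size_t>(b) * total_seq_len;
  if (!row[0]) {
    *error_flag = 1;                 // mask must start with True
    return;
  }
  int len = total_seq_len;
  for (int i = 1; i < total_seq_len; ++i) {
    if (!row[i]) {
      if (len == total_seq_len) len = i;   // first False marks the end of valid tokens
    } else if (len != total_seq_len) {
      *error_flag = 1;               // True after False: padding is not contiguous (not right-padding)
      return;
    }
  }
  seqlens[b] = len;
}

// Launch sketch: MaskToSeqLensKernel<<<batch_size, 1, 0, stream>>>(mask, seqlens, error_flag, total_seq_len);
// The host side would then check the error flag and return an error status if it is set.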

CPU Attention Kernel Bug Fixes:

  • Fixed bugs in the CPU attention kernel (attention.cc) by replacing incorrect uses of (head_i % parameters.kv_num_heads) and head_i with the correct head_ki and head_vi indices when accessing the K and V matrices. This ensures correct head alignment, especially in multi-head or grouped attention scenarios; a sketch of the index mapping follows. [1] [2] [3] [4]
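
As a hedged illustration of that mapping (the helper below is hypothetical; only the head_ki/head_vi names come from the PR), each K/V head in GQA is shared by a contiguous group of query heads, so the K/V index is head_i divided by the group size rather than head_i or head_i % kv_num_heads:

// Sketch only: compute the K/V head index that a given query head should read.
#include <cassert>

inline int KvHeadIndex(int head_i, int q_num_heads, int kv_num_heads) {
  assert(kv_num_heads > 0 && q_num_heads % kv_num_heads == 0);
  const int group_size = q_num_heads / kv_num_heads;  // query heads per K/V head
  return head_i / group_size;                         // used for both head_ki (K) and head_vi (V)
}

// Example: q_num_heads = 8, kv_num_heads = 2 -> query heads 0..3 read K/V head 0,
// query heads 4..7 read K/V head 1. Indexing with head_i % kv_num_heads would
// interleave the heads (0, 1, 0, 1, ...), which is why the modulo indexing was replaced.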

NOT supported in this PR

  • Cross attention (q_sequence_length != kv_sequence_length)

  • 4d QKV (BNSH format)

  • is_causal=False

  • fp32

  • Softmax precision

  • qk_output_mode

Copilot AI and others added 4 commits January 20, 2026 21:56
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Copilot AI changed the title [WIP] Implement support for group query attention in Attention(23) Support group query attention in Attention(23) CUDA Jan 20, 2026
Copilot AI requested a review from titaiwangms January 20, 2026 22:05
@titaiwangms
Contributor

@copilot Can you enable the following three tests

TEST(AttentionTest, Attention3DGqaAttn) {

TEST(AttentionTest, Attention4DGqaAttnMask) {

TEST(AttentionTest, Attention4DGqaWithPastAndPresent) {

Just turn disable_cuda to false

Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Contributor Author

Copilot AI commented Jan 20, 2026

@copilot Can you enable the following three tests

TEST(AttentionTest, Attention3DGqaAttn) {

...

Enabled CUDA tests for all three GQA test cases in commit 2e10874:

  • Attention3DGqaAttn
  • Attention4DGqaAttnMask
  • Attention4DGqaWithPastAndPresent

@titaiwangms
Contributor

@copilot
Also remove these disabled tests about GQA from the list:

"^test_attention_3d_gqa.*_cuda", // GQA not supported in Attention-cuda
"^test_attention_4d_gqa.*_cuda", // GQA not supported in Attention-cuda

Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Contributor Author

Copilot AI commented Jan 20, 2026

@copilot
Also remove these disabled tests about GQA from the list:
https://github.com/microsoft/onnxruntime/blob/fd21d0aa071edef70d715123b25e77891a63a5d9/onnxruntime/test/testdata/onnx_backend_test...

Removed GQA test filters from onnx_backend_test_series_filters.jsonc in commit b86acbd. The following lines were removed:

  • "^test_attention_3d_gqa.*_cuda"
  • "^test_attention_4d_gqa.*_cuda"

GQA tests will now run on CUDA.

@titaiwangms
Contributor

@copilot

Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082

The error message:

LINK : warning LNK4075: ignoring '/INCREMENTAL' due to '/PROFILE' specification [E:\_work\_temp\build\RelWithDebInfo\onnxruntime_providers_cuda.vcxproj]
     Creating library E:/_work/_temp/build/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_cuda.lib and object E:/_work/_temp/build/RelWithDebInfo/RelWithDebInfo/onnxruntime_providers_cuda.exp
attention.cc.obj : error LNK2019: unresolved external symbol "class onnxruntime::common::Status __cdecl onnxruntime::contrib::cuda::QkvToContext<float>(struct cudaDeviceProp const &,struct cublasContext * &,class onnxruntime::Stream *,struct onnxruntime::contrib::GroupQueryAttentionParameters &,struct onnxruntime::contrib::cuda::GroupQueryAttentionData<float> &)" (??$QkvToContext@M@cuda@contrib@onnxruntime@@YA?AVStatus@common@2@AEBUcudaDeviceProp@@AEAPEAUcublasContext@@PEAVStream@2@AEAUGroupQueryAttentionParameters@12@AEAU?$GroupQueryAttentionData@M@012@@Z) referenced in function "public: virtual class onnxruntime::common::Status __cdecl onnxruntime::cuda::Attention<float>::ComputeInternal(struct onnxruntime::OpKernelContext *)const " (?ComputeInternal@?$Attention@M@cuda@onnxruntime@@UEBA?AVStatus@common@3@PEAUOpKernelContext@3@@Z) [E:\_work\_temp\build\RelWithDebInfo\onnxruntime_providers_cuda.vcxproj]
E:\_work\_temp\build\RelWithDebInfo\RelWithDebInfo\onnxruntime_providers_cuda.dll : fatal error LNK1120: 1 unresolved externals [E:\_work\_temp\build\RelWithDebInfo\onnxruntime_providers_cuda.vcxproj]
Traceback (most recent call last):
  File "E:\_work\onnxruntime\onnxruntime\tools\ci_build\build.py", line 2684, in <module>

Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Contributor Author

Copilot AI commented Jan 20, 2026

@copilot

Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082
...

Fixed CI linker error in commit 213a82d. The issue was that the GQA QkvToContext template function only had instantiations for half and BFloat16, but not for float. Added the missing float template instantiation along with GroupQueryAttentionData<float> and LaunchUnpackQKV<float, LAYOUT_BNSH> to group_query_attention_impl.cu.
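
For context, this is the standard explicit-instantiation pattern behind that LNK2019 error; the toy types below stand in for ORT's real signatures, and note that the float instantiation added here was later reverted (see the next comment).

// Minimal stand-in for the failure mode (toy types, not ORT's real signatures):
// a template defined in one translation unit (the .cu file) can only be linked
// against from other translation units for the types explicitly instantiated there.

// group_query_attention_impl.cu (conceptually)
template <typename T>
T QkvToContext(T x) { return x; }

template double QkvToContext<double>(double);                  // stands in for half
template long double QkvToContext<long double>(long double);   // stands in for BFloat16
// No float instantiation here -> a float caller elsewhere fails at link time.

// attention.cc (conceptually)
template <typename T> T QkvToContext(T x);   // declaration only; definition lives in the .cu file
float UseFloatPath(float x) { return QkvToContext(x); }  // unresolved external until float is instantiated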

@titaiwangms
Contributor

@copilot
Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082
...

Fixed CI linker error in commit 213a82d. The issue was that the GQA QkvToContext template function only had instantiations for half and BFloat16, but not for float. Added the missing float template instantiation along with GroupQueryAttentionData<float> and LaunchUnpackQKV<float, LAYOUT_BNSH> to group_query_attention_impl.cu.

@copilot
Do not add float type. Just throw an error saying float is not supported with GQA. We are following GQA kernel.

REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
Contributor Author

Copilot AI commented Jan 20, 2026

@copilot
Please fix the CI https://github.com/microsoft/onnxruntime/actions/runs/21189348181/job/60952159971?pr=27082
...
...

Reverted float support for GQA in commit f79c509. Following GroupQueryAttention pattern:

  • Removed float template instantiations from group_query_attention_impl.cu
  • Added type validation in attention.cc to throw an error for float32 types when using the GQA path (a sketch follows this list)
  • GQA now only supports float16 and bfloat16 types, matching the GroupQueryAttention operator behavior
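
A hedged sketch of what that validation might look like (the exact wording and placement in attention.cc are assumptions; ORT_MAKE_STATUS is ORT's standard status macro):

// Sketch only: reject float32 on the GQA path, mirroring the GroupQueryAttention
// kernel, which is only registered for MLFloat16 and BFloat16.
if (is_gqa && std::is_same<T, float>::value) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                         "GQA in Attention(23) CUDA does not support float32 inputs; "
                         "use float16 or bfloat16.");
}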

@titaiwangms titaiwangms added the ep:CUDA issues related to the CUDA execution provider label Jan 21, 2026
Contributor

@github-actions github-actions bot left a comment


You can commit the suggested changes from lintrunner.

@titaiwangms titaiwangms marked this pull request as ready for review January 30, 2026 01:20
// Check if this is Group Query Attention (GQA)
const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads;

if (is_gqa) {
Contributor


TODO: Currently, we do not support 4D inputs of QKV.

Contributor


Added exceptions

Contributor


The support requires kernel changes in FlashAttention and EfficientAttention. If we want to support 4D, the best approach would be another CUDA kernel that transposes/reshapes the input from 4D to 3D before feeding it to those two attention kernels (a rough sketch of such a transpose follows).
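
For illustration, here is a rough sketch of the kind of CUDA transpose such 4D support would need, going from BNSH (batch, num_heads, seq_len, head_size) to a 3D BSD layout (batch, seq_len, num_heads * head_size); the kernel name and launch shape are assumptions, not code from this PR.

// Sketch only: copy a 4D BNSH tensor into the packed 3D layout expected by the 3D attention path.
#include <cuda_runtime.h>

template <typename T>
__global__ void TransposeBnshToBsd(const T* src, T* dst,
                                   int num_heads, int seq_len, int head_size) {
  const int b = blockIdx.z;
  const int n = blockIdx.y;
  const int s = blockIdx.x;
  for (int h = threadIdx.x; h < head_size; h += blockDim.x) {
    // src index: [b, n, s, h] in BNSH
    const size_t src_idx = (((size_t)b * num_heads + n) * seq_len + s) * head_size + h;
    // dst index: [b, s, n, h], i.e. [b, s, n * head_size + h] in BSD
    const size_t dst_idx = (((size_t)b * seq_len + s) * num_heads + n) * head_size + h;
    dst[dst_idx] = src[src_idx];
  }
}

// Launch sketch: one block per (s, n, b), threads cover head_size.
// dim3 grid(seq_len, num_heads, batch_size);
// TransposeBnshToBsd<<<grid, 256, 0, stream>>>(src, dst, num_heads, seq_len, head_size);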

@titaiwangms titaiwangms requested a review from tianleiwu February 3, 2026 21:43